{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 3: Convolutional Neural Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.utils.data\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CNNNet (or AlexNet by another name…)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNNNet(nn.Module):\n",
    "\n",
    "    def __init__(self, num_classes=2):\n",
    "        super(CNNNet, self).__init__()\n",
    "        self.features = nn.Sequential(\n",
    "            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "            nn.Conv2d(64, 192, kernel_size=5, padding=2),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "            nn.Conv2d(192, 384, kernel_size=3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(384, 256, kernel_size=3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(256, 256, kernel_size=3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "        )\n",
    "        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Dropout(),\n",
    "            nn.Linear(256 * 6 * 6, 4096),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(),\n",
    "            nn.Linear(4096, 4096),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(4096, num_classes)\n",
    "        )\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.features(x)\n",
    "        x = self.avgpool(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.classifier(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the classifier with its default two-output head.\n",
    "cnnnet = CNNNet(num_classes=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device=\"cpu\"):\n",
    "    for epoch in range(epochs):\n",
    "        training_loss = 0.0\n",
    "        valid_loss = 0.0\n",
    "        model.train()\n",
    "        for batch in train_loader:\n",
    "            optimizer.zero_grad()\n",
    "            inputs, targets = batch\n",
    "            inputs = inputs.to(device)\n",
    "            targets = targets.to(device)\n",
    "            output = model(inputs)\n",
    "            loss = loss_fn(output, targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            training_loss += loss.data.item() * inputs.size(0)\n",
    "        training_loss /= len(train_loader.dataset)\n",
    "        \n",
    "        model.eval()\n",
    "        num_correct = 0 \n",
    "        num_examples = 0\n",
    "        for batch in val_loader:\n",
    "            inputs, targets = batch\n",
    "            inputs = inputs.to(device)\n",
    "            output = model(inputs)\n",
    "            targets = targets.to(device)\n",
    "            loss = loss_fn(output,targets) \n",
    "            valid_loss += loss.data.item() * inputs.size(0)\n",
    "            correct = torch.eq(torch.max(F.softmax(output), dim=1)[1], targets).view(-1)\n",
    "            num_correct += torch.sum(correct).item()\n",
    "            num_examples += correct.shape[0]\n",
    "        valid_loss /= len(val_loader.dataset)\n",
    "\n",
    "        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss,\n",
    "        valid_loss, num_correct / num_examples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_image(path):\n",
    "    try:\n",
    "        im = Image.open(path)\n",
    "        return True\n",
    "    except:\n",
    "        return False\n",
    "# Every image is resized to 64x64, converted to a tensor, then normalised per channel.\n",
    "img_transforms = transforms.Compose([\n",
    "    transforms.Resize((64, 64)),\n",
    "    transforms.ToTensor(),\n",
    "    # NOTE(review): these look like the ImageNet channel mean/std listed in the\n",
    "    # torchvision docs for its pretrained models -- confirm they suit this dataset.\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                         std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "# Files that check_image rejects are left out of both datasets.\n",
    "train_data_path = \"./train/\"\n",
    "train_data = torchvision.datasets.ImageFolder(root=train_data_path, transform=img_transforms, is_valid_file=check_image)\n",
    "val_data_path = \"./val/\"\n",
    "val_data = torchvision.datasets.ImageFolder(root=val_data_path, transform=img_transforms, is_valid_file=check_image)\n",
    "\n",
    "batch_size = 64\n",
    "train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
    "# Validation metrics do not depend on batch order, so keep the order fixed\n",
    "# instead of reshuffling on every evaluation pass.\n",
    "val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Module.to moves the parameters in place and returns the same module,\n",
    "# so rebinding the name keeps pointing at the one network.\n",
    "cnnnet = cnnnet.to(device)\n",
    "optimizer = optim.Adam(params=cnnnet.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ten epochs of cross-entropy training on the selected device.\n",
    "train(\n",
    "    cnnnet,\n",
    "    optimizer,\n",
    "    torch.nn.CrossEntropyLoss(),\n",
    "    train_loader=train_data_loader,\n",
    "    val_loader=val_data_loader,\n",
    "    epochs=10,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Downloading a pretrained network \n",
    "\n",
    "There are two ways of downloading pre-trained image models with PyTorch: you can use the `torchvision.models` library, or you can use PyTorch Hub. The latter is preferred as of 2019, as it is a one-stop shop for all models and the new standard for distributing models with PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import models\n",
    "\n",
    "# Fetch AlexNet together with its pretrained weights (1000 output classes).\n",
    "alexnet = models.alexnet(pretrained=True, num_classes=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Without pretrained=True the hub entry point builds ResNet-50 with randomly\n",
    "# initialised weights, which is not what this section sets out to download.\n",
    "resnet50 = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the module tree of the AlexNet model loaded above.\n",
    "print(alexnet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the module tree of the ResNet-50 model fetched through PyTorch Hub.\n",
    "print(resnet50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
